Beam search 算法在文本生成中用得比较多,用于选择较优的结果(可能并不是最优的)。接下来将以seq2seq机器翻译为例来说明这个Beam search的算法思想。
在机器翻译中,beam search算法在测试的时候用的,因为在训练过程中,每一个decoder的输出是有与之对应的正确答案做参照,也就不需要beam search去加大输出的准确率。
有如下从中文到英语的翻译:
中文:
我 爱 学习,学习 使 我 快乐
英语:
I love learning, learning makes me happy
在这个测试中,中文的词汇表是{我,爱,学习,使,快乐},长度为5。英语的词汇表是{I, love, learning, make, me, happy}(全部转化为小写),长度为6。那么首先使用seq2seq中的编码器对中文序列(记这个中文序列为$X$)进行编码,得到语义向量$C$。
得到语义向量$C$后,进入解码阶段,依次翻译成目标语言。在正式解码之前,有一个参数需要设置,那就是beam search中的beam size,这个参数就相当于top-k中的k,选择前k个最有可能的结果。在本例中,我们选择beam size=3。
来看解码器的第一个输出$y_1$,在给定语义向量$C$的情况下,首先选择英语词汇表中最有可能k个单词,也就是依次选择条件概率$P(y_1|C)$前3大对应的单词,比如这里概率最大的前三个单词依次是$I$,$learning$,$happy$。
接着生成第二个输出$y_2$,在这个时候我们得到了那些东西呢,首先我们得到了编码阶段的语义向量$C$,还有第一个输出$y_1$。此时有个问题,$y_1$有三个,怎么作为这一时刻的输入呢(解码阶段需要将前一时刻的输出作为当前时刻的输入),答案就是都试下,具体做法是:
- 确定$I$为第一时刻的输出,将其作为第二时刻的输入,得到在已知$(C, I)$的条件下,各个单词作为该时刻输出的条件概率$P(y_2|C,I)$,有6个组合,每个组合的概率为$P(I|C)P(y_2|C, I)$。
- 确定$learning$为第一时刻的输出,将其作为第二时刻的输入,得到该条件下,词汇表中各个单词作为该时刻输出的条件概率$P(y_2|C, learning)$,这里同样有6种组合;
- 确定$happy$为第一时刻的输出,将其作为第二时刻的输入,得到该条件下各个单词作为输出的条件概率$P(y_2|C, happy)$,得到6种组合,概率的计算方式和前面一样。
这样就得到了18个组合,每一种组合对应一个概率值$P(y_1|C)P(y_2|C, y_1)$,接着在这18个组合中选择概率值top3的那三种组合,假设得到$I love$,$I happy$,$learning make$。
接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:
以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。
Beam search算法也有许多改进的地方,根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数(记生成序列的单词数为$N$),这样比较的就是每个单词的平均概率大小。
此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。
参考文献:
吴恩达-《序列模型》课程